import torch
# Build a 3x3 matrix holding the integers 0..8 row by row, then add two
# leading size-1 axes so the result has shape (1, 1, 3, 3).
# NOTE(review): the extra axes presumably give a (batch, channel, height,
# width) layout for a 2-D conv layer; confirm against the consumer.
X = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]).unsqueeze(0).unsqueeze(0)
print(X.shape)


